{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gỡ lỗi quy trình huấn luyện"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install datasets evaluate transformers[sentencepiece]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ValueError: No gradients provided for any variable: ['tf_distil_bert_for_sequence_classification/distilbert/embeddings/word_embeddings/weight:0', '...']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "import evaluate\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    TFAutoModelForSequenceClassification,\n",
    ")\n",
    "\n",
    "raw_datasets = load_dataset(\"glue\", \"mnli\")\n",
    "\n",
    "model_checkpoint = \"distilbert-base-uncased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    return tokenizer(examples[\"premise\"], examples[\"hypothesis\"], truncation=True)\n",
    "\n",
    "\n",
    "tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)\n",
    "\n",
    "train_dataset = tokenized_datasets[\"train\"].to_tf_dataset(\n",
    "    columns=[\"input_ids\", \"labels\"], batch_size=16, shuffle=True\n",
    ")\n",
    "\n",
    "validation_dataset = tokenized_datasets[\"validation_matched\"].to_tf_dataset(\n",
    "    columns=[\"input_ids\", \"labels\"], batch_size=16, shuffle=True\n",
    ")\n",
    "\n",
    "model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint)\n",
    "\n",
    "model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")\n",
    "\n",
    "model.fit(train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'attention_mask': <tf.Tensor: shape=(16, 76), dtype=int64, numpy=\n",
       " array([[1, 1, 1, ..., 0, 0, 0],\n",
       "        [1, 1, 1, ..., 0, 0, 0],\n",
       "        [1, 1, 1, ..., 0, 0, 0],\n",
       "        ...,\n",
       "        [1, 1, 1, ..., 1, 1, 1],\n",
       "        [1, 1, 1, ..., 0, 0, 0],\n",
       "        [1, 1, 1, ..., 0, 0, 0]])>,\n",
       " 'label': <tf.Tensor: shape=(16,), dtype=int64, numpy=array([0, 2, 1, 2, 1, 1, 2, 0, 0, 0, 1, 0, 1, 2, 2, 1])>,\n",
       " 'input_ids': <tf.Tensor: shape=(16, 76), dtype=int64, numpy=\n",
       " array([[ 101, 2174, 1010, ...,    0,    0,    0],\n",
       "        [ 101, 3174, 2420, ...,    0,    0,    0],\n",
       "        [ 101, 2044, 2048, ...,    0,    0,    0],\n",
       "        ...,\n",
       "        [ 101, 3398, 3398, ..., 2051, 2894,  102],\n",
       "        [ 101, 1996, 4124, ...,    0,    0,    0],\n",
       "        [ 101, 1999, 2070, ...,    0,    0,    0]])>}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for batch in train_dataset:\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "  246/24543 [..............................] - ETA: 15:52 - loss: nan"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.compile(optimizer=\"adam\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TFSequenceClassifierOutput(loss=<tf.Tensor: shape=(16,), dtype=float32, numpy=\n",
       "array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,\n",
       "       nan, nan, nan], dtype=float32)>, logits=<tf.Tensor: shape=(16, 2), dtype=float32, numpy=\n",
       "array([[nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan],\n",
       "       [nan, nan]], dtype=float32)>, hidden_states=None, attentions=None)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TFSequenceClassifierOutput(loss=<tf.Tensor: shape=(16,), dtype=float32, numpy=\n",
       "array([0.6844486 ,        nan,        nan, 0.67127866, 0.7068601 ,\n",
       "              nan, 0.69309855,        nan, 0.65531296,        nan,\n",
       "              nan,        nan, 0.675402  ,        nan,        nan,\n",
       "       0.69831556], dtype=float32)>, logits=<tf.Tensor: shape=(16, 2), dtype=float32, numpy=\n",
       "array([[-0.04761693, -0.06509043],\n",
       "       [-0.0481936 , -0.04556257],\n",
       "       [-0.0040929 , -0.05848458],\n",
       "       [-0.02417453, -0.0684005 ],\n",
       "       [-0.02517801, -0.05241832],\n",
       "       [-0.04514256, -0.0757378 ],\n",
       "       [-0.02656011, -0.02646275],\n",
       "       [ 0.00766164, -0.04350497],\n",
       "       [ 0.02060014, -0.05655622],\n",
       "       [-0.02615328, -0.0447021 ],\n",
       "       [-0.05119278, -0.06928903],\n",
       "       [-0.02859691, -0.04879177],\n",
       "       [-0.02210129, -0.05791225],\n",
       "       [-0.02363213, -0.05962167],\n",
       "       [-0.05352269, -0.0481673 ],\n",
       "       [-0.08141848, -0.07110836]], dtype=float32)>, hidden_states=None, attentions=None)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint)\n",
    "model(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 1,  2,  5,  7,  9, 10, 11, 13, 14])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "loss = model(batch).loss.numpy()\n",
    "indices = np.flatnonzero(np.isnan(loss))\n",
    "indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  101,  2007,  2032,  2001,  1037, 16480,  3917,  2594,  4135,\n",
       "        23212,  3070,  2214, 10170,  1010,  2012,  4356,  1997,  3183,\n",
       "         6838, 12953,  2039,  2000,  1996,  6147,  1997,  2010,  2606,\n",
       "         1012,   102,  6838,  2001,  3294,  6625,  3773,  1996,  2214,\n",
       "         2158,  1012,   102,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0],\n",
       "       [  101,  1998,  6814,  2016,  2234,  2461,  2153,  1998, 13322,\n",
       "         2009,  1012,   102,  2045,  1005,  1055,  2053,  3382,  2008,\n",
       "         2016,  1005,  2222,  3046,  8103,  2075,  2009,  2153,  1012,\n",
       "          102,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0],\n",
       "       [  101,  1998,  2007,  1996,  3712,  4634,  1010,  2057,  8108,\n",
       "         2025,  3404,  2028,  1012,  1996,  2616, 18449,  2125,  1999,\n",
       "         1037,  9666,  1997,  4100,  8663, 11020,  6313,  2791,  1998,\n",
       "         2431,  1011,  4301,  1012,   102,  2028,  1005,  1055,  5177,\n",
       "         2110,  1998,  3977,  2000,  2832,  2106,  2025,  2689,  2104,\n",
       "         2122,  6214,  1012,   102,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0],\n",
       "       [  101,  1045,  2001,  1999,  1037, 13090,  5948,  2007,  2048,\n",
       "         2308,  2006,  2026,  5001,  2043,  2026,  2171,  2001,  2170,\n",
       "         1012,   102,  1045,  2001,  3564,  1999,  2277,  1012,   102,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0],\n",
       "       [  101,  2195,  4279,  2191,  2039,  1996,  2181,  2124,  2004,\n",
       "         1996,  2225,  7363,  1012,   102,  2045,  2003,  2069,  2028,\n",
       "         2451,  1999,  1996,  2225,  7363,  1012,   102,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0],\n",
       "       [  101,  2061,  2008,  1045,  2123,  1005,  1056,  2113,  2065,\n",
       "         2009,  2428, 10654,  7347,  2030,  2009,  7126,  2256,  2495,\n",
       "         2291,   102,  2009,  2003,  5094,  2256,  2495,  2291,  2035,\n",
       "         2105,  1012,   102,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0],\n",
       "       [  101,  2051,  1010,  2029,  3216,  2019,  2503,  3444,  1010,\n",
       "         6732,  1996,  2265,  2038, 19840,  2098,  2125,  9906,  1998,\n",
       "         2003,  2770,  2041,  1997,  4784,  1012,   102,  2051,  6732,\n",
       "         1996,  2265,  2003,  9525,  1998,  4569,  1012,   102,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0],\n",
       "       [  101,  1996, 10556,  2140, 11515,  2058,  1010,  2010,  2162,\n",
       "         2252,  5689,  2013,  2010,  7223,  1012,   102,  2043,  1996,\n",
       "        10556,  2140, 11515,  2058,  1010,  2010,  2252,  3062,  2000,\n",
       "         1996,  2598,  1012,   102,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0],\n",
       "       [  101, 13543,  1999,  2049,  6143,  2933,  2443,   102,  2025,\n",
       "        13543,  1999,  6143,  2933,  2003,  2443,   102,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_ids = batch[\"input_ids\"].numpy()\n",
    "input_ids[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.config.num_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.optimizers import Adam\n",
    "\n",
    "model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint)\n",
    "model.compile(optimizer=Adam(5e-5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "319/24543 [..............................] - ETA: 16:07 - loss: 0.9718"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids = batch[\"input_ids\"].numpy()\n",
    "tokenizer.decode(input_ids[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = batch[\"labels\"].numpy()\n",
    "label = labels[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for batch in train_dataset:\n",
    "    break\n",
    "\n",
    "# Đảm bảo rằng bạn đã chạy model.compile() và đặt trình tối ưu hóa của mình,\n",
    "# và mất mát/chỉ số của bạn nếu bạn đang sử dụng chúng\n",
    "\n",
    "model.fit(batch, epochs=20)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Gỡ lỗi quy trình huấn luyện",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
